[MRG] 32/64-bit float consistency with BernoulliRBM #16352
Conversation
# dtype_in and dtype_out consistent
assert Xt.dtype == dtype_out, ('transform dtype: {} - original dtype: {}'
                               .format(Xt.dtype, X.dtype))
While #16290 is not merged, we should also add a check that the results of fit_transform are close enough with float32 and float64 inputs.
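A minimal sketch of what such a check could look like (the data shapes, tolerances, and variable names here are illustrative, not taken from #16290):

import numpy as np
from numpy.testing import assert_allclose
from sklearn.neural_network import BernoulliRBM

rng = np.random.RandomState(42)
X_64 = rng.uniform(size=(20, 10))  # float64 input in [0, 1)
X_32 = X_64.astype(np.float32)     # same data as float32

Xt_64 = BernoulliRBM(n_components=5, random_state=0).fit_transform(X_64)
Xt_32 = BernoulliRBM(n_components=5, random_state=0).fit_transform(X_32)

# fit_transform results should agree up to roughly float32 precision
assert_allclose(Xt_64, Xt_32, rtol=1e-5, atol=1e-6)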
We should also check that the attributes of the estimator are close for float32 and float64 inputs, which I think is better done in the individual tests.
Just added a dedicated test for that.
Is this more susceptible to loss of precision in intermediate matrix multiplications?
I guess this question could apply to any neural net architecture, and yet all DL libraries use float32, and lately sometimes even float16. Maybe they just don't enforce that float64/float32 outputs are identical? Not sure. If so, I wonder whether we should. Possibly checking that the convergence criterion was reached, whatever it was, is enough (at least for some types of algorithms).
LGTM. Just a small comment. Also, please add a what's new entry.
assert_almost_equal(Xt_64, Xt_32, 6)
assert_almost_equal(rbm_64.intercept_hidden_,
                    rbm_32.intercept_hidden_,
                    6)
assert_almost_equal(rbm_64.intercept_visible_,
                    rbm_32.intercept_visible_,
                    6)
assert_almost_equal(rbm_64.components_, rbm_32.components_, 6)
assert_almost_equal(rbm_64.h_samples_, rbm_32.h_samples_, 0)
Please use assert_allclose instead.
Done.
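For reference, the same checks expressed with assert_allclose might look like this (the atol values are rough equivalents, since assert_almost_equal with decimal=d verifies absolute differences below 1.5 * 10**-d; they are not the tolerances from the merged PR):

from numpy.testing import assert_allclose

assert_allclose(Xt_64, Xt_32, atol=1.5e-6)
assert_allclose(rbm_64.intercept_hidden_, rbm_32.intercept_hidden_, atol=1.5e-6)
assert_allclose(rbm_64.intercept_visible_, rbm_32.intercept_visible_, atol=1.5e-6)
assert_allclose(rbm_64.components_, rbm_32.components_, atol=1.5e-6)
assert_allclose(rbm_64.h_samples_, rbm_32.h_samples_, atol=1.5)  # was decimal=0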
sklearn/neural_network/_rbm.py
@@ -146,8 +146,10 @@ def _mean_hiddens(self, v):
        h : ndarray of shape (n_samples, n_components)
            Corresponding mean field values for the hidden layer.
        """
sklearn/neural_network/_rbm.py
        p = safe_sparse_dot(v, self.components_.T)
        p += self.intercept_hidden_
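The in-place addition is what keeps the result in the input dtype here: safe_sparse_dot of two float32 operands returns float32, and += does not upcast, whereas a plain + with a float64 intercept would promote the whole result to float64. A self-contained sketch of that NumPy behavior (example values are illustrative, not code from the PR):

import numpy as np
from sklearn.utils.extmath import safe_sparse_dot

v = np.ones((4, 3), dtype=np.float32)
components = np.ones((2, 3), dtype=np.float32)
intercept = np.zeros(2)  # np.zeros defaults to float64

p = safe_sparse_dot(v, components.T)  # float32 @ float32 -> float32
p += intercept                        # in-place add: p stays float32
assert p.dtype == np.float32

q = safe_sparse_dot(v, components.T) + intercept  # plain + upcasts
assert q.dtype == np.float64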
LGTM, thanks!
doc/whats_new/v0.23.rst
@@ -235,6 +235,10 @@ Changelog
  :class:`neural_network.MLPClassifier` by clipping the probabilities.
  :pr:`16117` by `Thomas Fan`_.

- |Enhancement| Prevent the transformer from converting float32 to float64 in
  :class:`neural_network.BernoulliRBM`.
Maybe "Avoid converting float32 input to float64 in ..."?
Much better ^^
Thanks @Henley13!
Reference Issues/PRs
Works on #11000 for BernoulliRBM.
What does this implement/fix? Explain your changes.
Prevent the transformer from converting float32 input to float64.
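Concretely, fit_transform on float32 input should now return float32. A quick sketch of the expected behavior (data and parameters are illustrative):

import numpy as np
from sklearn.neural_network import BernoulliRBM

X = np.random.RandomState(0).uniform(size=(50, 16)).astype(np.float32)
Xt = BernoulliRBM(n_components=8, random_state=0).fit_transform(X)
assert Xt.dtype == np.float32  # no silent upcast to float64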
Any other comments?
Maybe we should wait until a generic test for dtype consistency (see #16290) is merged before merging this one.